




import time
import datetime
from collections import deque
from quant.utils import InfluxClient


class InfluxRawLogger:
    """Write raw market messages to InfluxDB and read them back.

    Each (event, exchange, symbol, frequency) stream is stored in its own
    measurement, named by ``table_name``. Every point carries the raw payload,
    a per-measurement sequence number ``count``, the original ``recv_time``
    and the ``log_ts`` used as the point's timestamp.
    """

    def __init__(self, database='raw', delay=1):
        """
        :param database: InfluxDB database name used for reads and writes.
        :param delay: forwarded to ``InfluxClient`` as its last constructor
            argument (presumably a write delay/batching interval — confirm).
        """
        self.database = database
        self.delay = delay
        self.influx_client = None
        # Next sequence number to write, keyed by measurement name.
        self._count = {}

        # Last point timestamp written; keeps timestamps strictly increasing.
        self._log_ts = 0

    def connect(self, host, port, user_name, password):
        """Create the underlying ``InfluxClient``; no-op if already connected."""
        if self.influx_client is not None:
            print('InfluxRawLogger.connect(), database already connected')
            return

        self.influx_client = InfluxClient(host, port, user_name, password, self.delay)

    def write(self, event, exchange, symbol, frequency, recv_time, raw):
        """Append one raw message to the measurement of its stream.

        The point timestamp ``log_ts`` is ``recv_time`` rounded to microseconds,
        bumped by 1 µs whenever it would not exceed the previous one, so that
        timestamps written by this logger are strictly increasing (presumably
        so two messages never share a timestamp and overwrite each other).
        """
        if self.influx_client is None:
            print('InfluxRawLogger.write(), database not connected')
            return

        log_ts = round(recv_time, 6)
        if log_ts <= self._log_ts:
            log_ts = round(self._log_ts + 0.000001, 6)
        self._log_ts = log_ts

        name = table_name(event, exchange, symbol, frequency)
        count = self._count.get(name, 0)
        self._count[name] = count + 1
        fields = {'raw': raw, 'count': count, 'recv_time': recv_time, 'log_ts': log_ts}
        self.influx_client.write(self.database, name, {}, fields, log_ts)

    def read(self, event, exchange, symbol, frequency, begin, end=None, limit=None):
        """Read the stream identified by its parts; see ``read_with_name``."""
        name = table_name(event, exchange, symbol, frequency)
        return self.read_with_name(name, begin, end, limit)

    def read_with_name(self, name, begin, end=None, limit=None, include_left=True):
        """Query all fields of measurement ``name`` from ``begin`` onwards.

        :param name: measurement name.
        :param begin: lower time bound, UNIX seconds (interpreted as UTC).
        :param end: optional exclusive upper time bound, UNIX seconds.
        :param limit: optional maximum number of rows.
        :param include_left: True for ``time >= begin``, False for ``time > begin``.
        :return: the result of ``InfluxClient.query``, or None when not connected.
        """
        if self.influx_client is None:
            print('InfluxRawLogger.read_with_name(), database not connected')
            return None

        op = '>=' if include_left else '>'
        clauses = ["time {} '{}'".format(op, self._utc_str(begin))]
        if end is not None:
            clauses.append("time < '{}'".format(self._utc_str(end)))

        # Clauses are joined with explicit spaces; the previous concatenation
        # produced "...'AND time" and "...'LIMIT n".
        sql = "SELECT * FROM \"{}\" WHERE {}".format(name, ' AND '.join(clauses))
        if limit is not None:
            sql += ' LIMIT {}'.format(limit)

        return self.influx_client.query(sql, self.database)

    @staticmethod
    def _utc_str(ts):
        """Format UNIX seconds as a naive UTC datetime string for InfluxQL."""
        # Same output as str(datetime.utcfromtimestamp(ts)), which is deprecated.
        dt = datetime.datetime.fromtimestamp(ts, datetime.timezone.utc)
        return str(dt.replace(tzinfo=None))


class SimDataAgent:
    """Connect a markets object to an ``InfluxRawLogger`` for recording and replay.

    ``start_collect`` records raw messages from ``market``; ``start_feed`` and
    ``start_feed_intv`` replay recorded messages back into ``market``.
    """

    def __init__(self, market):
        self.market = market
        self.logger: InfluxRawLogger = None

    def connect_database(self, host, port, user_name, password, db_name='raw', write_delay=1):
        """Create and connect the logger; no-op if already connected."""
        if self.logger is not None:
            print('DataEngine.connect_database(), database already connected')
            return

        self.logger = InfluxRawLogger(db_name, write_delay)
        self.logger.connect(host, port, user_name, password)

    def start_collect(self):
        """Subscribe ``on_raw`` to the market's raw stream so messages get logged."""
        if self.logger is None:
            print('DataEngine.start_collect(), database not connected')
            return

        self.market.subscribe_raw(self.on_raw)

    def start_feed(self, begin, end):  # '2022-01-01T11:00:00Z'
        """Replay recorded data between two '%Y-%m-%dT%H:%M:%SZ' strings."""
        if self.logger is None:
            print('DataEngine.start_feed(), database not connected')
            return

        feed_assistant = FeedAssistant(self.market, self)
        feed_assistant.start_feed(begin, end)

    def start_feed_intv(self, begin, intv, pre, call):  # '2022-01-01T11:00:00Z'
        """Replay recorded data from ``begin``, invoking ``call()`` every ``intv`` seconds.

        See ``FeedAssistant.start_feed_intv`` for the meaning of ``pre``.
        """
        if self.logger is None:
            print('DataEngine.start_feed_intv(), database not connected')
            return

        feed_assistant = FeedAssistant(self.market, self)
        feed_assistant.start_feed_intv(begin, intv, pre, call)

    def on_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        """Raw-message callback: forward the message to the logger."""
        self.logger.write(event, exchange, symbol, frequency, recv_time, raw)


class FeedAssistant:
    """Replay recorded raw messages into ``markets`` in recv_time order.

    One fetcher is created per subscribed (event, exchange, symbol, frequency).
    On every step, the fetcher whose next message has the smallest recv_time
    hands out one message, which is passed to ``markets.feed_raw``.
    """

    _TIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ'

    def __init__(self, markets, data_engine):
        self.markets = markets
        # Provides ``.logger``, which the fetchers read from.
        self.data_engine = data_engine

        self._name_to_fetcher = {}
        self._all_fetcher = []

        # Replay start in UNIX seconds; given to every fetcher created.
        self._begin = None

    def start_feed(self, begin, end):  # '2022-01-01T11:00:00Z'
        """Replay from ``begin`` until a message with recv_time > ``end`` or data runs out.

        :param begin: UTC time string in '%Y-%m-%dT%H:%M:%SZ' format.
        :param end: UTC time string in the same format.
        """
        self._begin = self._parse_utc(begin)
        end_ts = self._parse_utc(end)

        self.markets.add_event_engine.subscribe_add_event(self._on_subscribe)

        while True:
            item = self._pop_next()
            if item is None:
                print("finish")
                return

            f, data = item
            if data[0] > end_ts:
                print("finish")
                return

            self.markets.feed_raw(f.event, f.exchange, f.symbol, f.frequency, *data)

    def start_feed_intv(self, begin, intv, pre, call):  # '2022-01-01T11:00:00Z'
        """Replay from ``begin`` until data runs out, calling ``call()`` periodically.

        :param begin: UTC time string in '%Y-%m-%dT%H:%M:%SZ' format.
        :param intv: minimum recv_time gap, in seconds, between two calls.
        :param pre: warm-up seconds after the first replayed message during
            which ``call`` is never invoked.
        :param call: zero-argument callback.
        """
        self._begin = self._parse_utc(begin)

        first = None
        last = self._begin

        self.markets.add_event_engine.subscribe_add_event(self._on_subscribe)

        while True:
            item = self._pop_next()
            if item is None:
                print("finish")
                return

            f, data = item
            self.markets.feed_raw(f.event, f.exchange, f.symbol, f.frequency, *data)

            recv = data[0]
            if first is None:
                first = recv
                continue

            if recv < first + pre:
                continue

            if recv > last + intv:
                call()
                last = recv

    def _on_subscribe(self, event, exchange, symbol, frequency):
        """Create a fetcher for a newly subscribed stream (once per stream)."""
        name = table_name(event, exchange, symbol, frequency)
        if name in self._name_to_fetcher:
            return

        fetcher = SingleTableFetcher(event, exchange, symbol, frequency, self.data_engine.logger)
        fetcher.set_begin(self._begin)
        self._name_to_fetcher[name] = fetcher
        self._all_fetcher.append(fetcher)

    def _pop_next(self):
        """Pop the earliest pending message as ``(fetcher, (recv_time, raw))``, or None if exhausted."""
        f = self._find_fetchers()
        if f is None:
            return None

        data = f.pop_next()
        if data is None:
            return None

        return f, data

    def _find_fetchers(self):
        """Return the fetcher with the earliest next recv_time, or None if there are no fetchers."""
        # default=None: with no subscriptions, min() would otherwise raise ValueError.
        return min(self._all_fetcher, key=self._sort_fetcher_key, default=None)

    @classmethod
    def _parse_utc(cls, text):
        """Parse a '%Y-%m-%dT%H:%M:%SZ' string as UTC and return UNIX seconds."""
        dt = datetime.datetime.strptime(text, cls._TIME_FORMAT)
        # The trailing 'Z' means UTC. A naive .timestamp() would apply the host's local zone.
        return dt.replace(tzinfo=datetime.timezone.utc).timestamp()

    @staticmethod
    def _sort_fetcher_key(f):
        """Sort key: next recv_time, with exhausted fetchers ordered last."""
        recv = f.next_recv_time()
        if recv is None:
            recv = float('inf')
        return recv


class SingleTableFetcher:
    """Page through one recorded measurement in time order, one row at a time.

    Rows are requested ``Ache_Count`` at a time via ``logger.read_with_name``.
    Each page starts strictly after the timestamp of the previous page's last
    row. Consumers use ``next_recv_time()`` / ``pop_next()``.
    """

    def __init__(self, event, exchange, symbol, frequency, logger):
        self.event = event
        self.exchange = exchange
        self.symbol = symbol
        self.frequency = frequency

        self.table_name = table_name(event, exchange, symbol, frequency)
        # Must provide read_with_name(...), e.g. an InfluxRawLogger.
        self.logger = logger
        self._data = deque()  # [(time, count, recv_time, raw), ...]

        self._last_ts = None    # UNIX seconds; the next page starts strictly after this
        self._count = None      # last sequence number handed out, for gap detection
        self._finished = False  # set once a query returns no rows

    def set_begin(self, begin):
        """Set the (exclusive) start point, in UNIX seconds, before the first fetch."""
        self._last_ts = begin

    def next_recv_time(self):
        """Return the next row's recv_time without consuming it, or None when exhausted."""
        if not self._data:
            self._fetch_next()

        if not self._data:
            return None

        return self._data[0][2]

    def pop_next(self):  # (recv_time, raw)
        """Consume the next row and return ``(recv_time, raw)``, or None when exhausted.

        :raises Exception: if the row's ``count`` breaks the ascending sequence.
        """
        if not self._data:
            self._fetch_next()

        if not self._data:
            return None

        _, count, recv_time, raw = self._data.popleft()
        self._check_count(count)

        return (recv_time, raw)

    def _fetch_next(self):
        """Load the next page into ``self._data``; mark finished when no rows remain.

        :raises Exception: if ``set_begin`` was not called, or all 5 attempts fail.
        """
        if self._finished:
            return

        if self._last_ts is None:
            err = 'SingleTable not set_begin() yet'
            raise Exception(err)

        data = None
        last_err = None
        for _ in range(5):
            try:
                data = self._try_fetch()
                break
            except Exception as err:
                last_err = err
                print('SingleTable._try_fetch() err:{}'.format(err))
                print('retry')

        if data is None:
            err = 'SingleTable fetch data fail'
            raise Exception(err) from last_err

        if len(data) == 0:
            self._finished = True
            return

        rows = self._format_data(data)
        if not rows:
            self._finished = True
            return

        self._data = deque(rows)
        # The next page starts strictly after the last row of this one.
        self._last_ts = self._parse_influx_time(rows[-1][0])

    def _try_fetch(self):  # {'columns': [...], 'values': [[...], ...]} or []
        """Run one page query and return its single series dict, or [] if empty."""
        begin = self._last_ts
        result = self.logger.read_with_name(self.table_name, begin, limit=Ache_Count, include_left=False)
        return self._extract_series(result)

    @staticmethod
    def _extract_series(result):
        """Return the only series of a query result, or [] when it has none.

        :raises Exception: if the result holds more than one series.
        """
        try:
            series = result.raw['series']
        except KeyError:
            # NOTE(review): InfluxDB omits 'series' when nothing matches (as
            # influxdb-python's ResultSet.raw does); previously this KeyError
            # was retried 5 times and then reported as a fetch failure.
            series = []

        if len(series) == 0:
            return []

        if len(series) != 1:
            raise Exception('SingleTable expected 1 series, got {}'.format(len(series)))
        return series[0]

    @staticmethod
    def _parse_influx_time(text):
        """Convert an RFC3339 UTC time string (as returned by InfluxDB) to UNIX seconds.

        Accepts strings with or without a fractional part. Digits beyond
        microseconds are truncated because ``datetime`` cannot hold them.
        """
        body = text[:-1] if text.endswith('Z') else text
        if '.' in body:
            head, frac = body.split('.', 1)
            # NOTE(review): truncating sub-µs digits could re-read a boundary row.
            body = '{}.{}'.format(head, frac[:6])
            fmt = '%Y-%m-%dT%H:%M:%S.%f'
        else:
            fmt = '%Y-%m-%dT%H:%M:%S'
        dt = datetime.datetime.strptime(body, fmt)
        # Interpret as UTC directly. The old "timestamp() - time.timezone"
        # went wrong whenever the host was in daylight-saving time.
        return dt.replace(tzinfo=datetime.timezone.utc).timestamp()

    def _format_data(self, result):
        """Project a series dict into [(time, count, recv_time, raw), ...] by column name."""
        cols = result['columns']
        time_i = cols.index('time')
        count_i = cols.index('count')
        recv_time_i = cols.index('recv_time')
        raw_i = cols.index('raw')
        return [(v[time_i], v[count_i], v[recv_time_i], v[raw_i]) for v in result['values']]

    def _check_count(self, count):
        """Verify sequence numbers ascend by one.

        A count of 0 is accepted at any point. Presumably this marks a new
        writer session, since the writer's per-table counter starts at 0.

        :raises Exception: on any other gap or repetition.
        """
        if self._count is None or count == 0:
            self._count = count
            return

        if count != self._count + 1:
            err = 'Data count not ascending: {}, {} -> {}'.format(self.table_name, self._count, count)
            raise Exception(err)

        self._count = count


def table_name(event, exchange, symbol, frequency):
    """Return the measurement name for one stream: '<event>_<exchange>_<symbol>_<frequency>'."""
    return f'{event}_{exchange}_{symbol}_{frequency}'


# Page size: maximum number of rows SingleTableFetcher requests per query
# (passed as ``limit`` to read_with_name). Read at call time, so it can be
# reassigned at runtime via ``global Ache_Count``.
Ache_Count = 1000


if __name__ == '__main__':
    import os
    from quant.markets import Markets, SimMarkets

    # Manual smoke tests against a live InfluxDB instance.
    # NOTE(review): credentials are committed in source; prefer supplying them
    # via the INFLUX_* environment variables and removing these defaults.
    host = os.environ.get('INFLUX_HOST', '47.243.236.229')
    port = int(os.environ.get('INFLUX_PORT', '8086'))
    user_name = os.environ.get('INFLUX_USER', 'johnsquant')
    password = os.environ.get('INFLUX_PASSWORD', 'wo5892132156')

    def try_collect():
        """Subscribe to a Binance BTC book market and log its raw messages."""
        mar = Markets()
        mar.add_market('Book', 'Binance', 'btc/usdt.swap')

        engine = SimDataAgent(mar)
        engine.connect_database(host, port, user_name, password)
        engine.start_collect()

    def try_push():
        """Replay recorded BTC/ETH book data through SimMarkets and print it."""
        global Ache_Count
        Ache_Count = 1000

        mar = SimMarkets()
        mar.add_market('Book', 'Binance', 'btc/usdt.swap', 'Normal')
        mar.add_market('Book', 'Binance', 'eth/usdt.swap', 'Normal')

        def on_md(md):
            print(datetime.datetime.fromtimestamp(md.recv_time))
            print(md)
        mar.subscribe_all(on_md)

        def on_interval():
            print('interval')

        engine = SimDataAgent(mar)
        engine.connect_database(host, port, user_name, password)

        start = '2022-08-10T16:31:00Z'
        end = '2022-08-22T16:32:00Z'

        print("start")
        # engine.start_feed(start, end)
        # start_feed_intv(begin, intv, pre, call) requires all four arguments;
        # the previous call passed only two and raised TypeError.
        engine.start_feed_intv(start, 1, 0, on_interval)

        input(':')

    try_push()


# TODO: the markets push path should check that recv_time is monotonically increasing.









